{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0381b2a0",
   "metadata": {},
   "source": [
    "## 5.8 模型文件的读写"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57e5824d",
   "metadata": {},
   "source": [
    "前面我们讲了众多方法，来训练好一个模型，让模型能够收敛，同时又不出现过拟合和欠拟合问题。当模型训练好以后，就需要对模型参数进行保存，以便部署到不同的环境中去使用。通常，一个深度学习模型的训练需要消耗很长的时间，比如几天，如果训练过程中出现了问题，无论是软件问题还是硬件问题，又或者是外部因素比如突然断电，造成的损失都是非常大的。因此，我们不仅要在最后对模型进行保存，训练过程中也应该定时保存中间结果，以避免损失。这一节我们就来学习一下模型保存的方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afb843a3",
   "metadata": {},
   "source": [
    "### 5.8.1 张量的保存和加载"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9f65645",
   "metadata": {},
   "source": [
    "在深度学习中，模型的参数一般是张量形式的。对于单个的张量，pytorch为我们提供了方便直接的函数来进行读写。比如我们定义如下的一个张量a。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6919b4bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,\n",
       "        0.8449])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "a = torch.rand(10)\n",
    "a"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2034384e",
   "metadata": {},
   "source": [
    "可以简单的用一个save函数去存储这个张量a，这里需要我们给他起一个名字，我们就叫它tensor-a,把它放在model文件夹里。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0f885f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(a, 'model/tensor-a')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3326392",
   "metadata": {},
   "source": [
    "这就完成了张量的写入，这时我们可以在当前路径下的model文件夹里看到tensor-a这个文件。读取同样简单，只需要用一个load函数就可以完成张量的加载，传入的参数是文件的路径。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4581be4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,\n",
       "        0.8449])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.load('model/tensor-a')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac257c65",
   "metadata": {},
   "source": [
    "如果我们要存储的不止一个张量，也没有问题，save和load函数同样支持保存张量列表。先把张量数据存储起来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "be993317",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,\n",
       "         0.5002]),\n",
       " tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,\n",
       "         0.2336]),\n",
       " tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,\n",
       "         0.1998])]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.rand(10)\n",
    "b = torch.rand(10)\n",
    "c = torch.rand(10)\n",
    "torch.save([a,b,c], 'model/tensor-abc')\n",
    "[a,b,c]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee6b6a08",
   "metadata": {},
   "source": [
    "然后再把它读取出来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0747fb22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,\n",
       "         0.5002]),\n",
       " tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,\n",
       "         0.2336]),\n",
       " tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,\n",
       "         0.1998])]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.load('model/tensor-abc')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e167a110",
   "metadata": {},
   "source": [
    "对于多个张量，pytorch同样支持以字典的形式来进行存储。比如我们建立一个字典tensor_dict,然后把它存起来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "db5894c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,\n",
       "         0.9062]),\n",
       " 'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,\n",
       "         0.0902]),\n",
       " 'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,\n",
       "         0.4338])}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.rand(10)\n",
    "b = torch.rand(10)\n",
    "c = torch.rand(10)\n",
    "tensor_dict={'a':a, 'b':b, 'c':c}\n",
    "torch.save(tensor_dict, 'model/tensor_dict')\n",
    "tensor_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e5e3f5f",
   "metadata": {},
   "source": [
    "然后是读取。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2396a27f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,\n",
       "         0.9062]),\n",
       " 'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,\n",
       "         0.0902]),\n",
       " 'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,\n",
       "         0.4338])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.load('model/tensor_dict')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5de4d6f",
   "metadata": {},
   "source": [
    "张量的读写非常的简单，接下来我们看看模型整体参数的读写。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "325bf4b3",
   "metadata": {},
   "source": [
    "## 5.8.2 模型参数的保存和加载"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa6843f4",
   "metadata": {},
   "source": [
    "模型参数一般都是张量形式的，虽然单个张量的保存和加载非常简单，但整个模型中包含着大大小小的若干张量，单独保存这些张量会很困难。为了解决这个问题，pytorch贴心的为我们准备了内置函数来保存加载整个模型参数。我们以5.2节的多层感知机为例，来看一下如何保存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "129beb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# 定义 MLP 网络\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        super(MLP, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.fc3 = nn.Linear(hidden_size, num_classes)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc3(out)\n",
    "        return out\n",
    "\n",
    "# 定义超参数\n",
    "input_size = 28 * 28  # 输入大小\n",
    "hidden_size = 512  # 隐藏层大小\n",
    "num_classes = 10  # 输出大小（类别数）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d749a44",
   "metadata": {},
   "source": [
    "然后我们实例化一个MLP网络，并随机生成一个输入X，并计算出模型的输出Y。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e9b684b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 实例化 MLP 网络\n",
    "model = MLP(input_size, hidden_size, num_classes)\n",
    "X = torch.randn(size=(2, 28*28))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "618d896b",
   "metadata": {},
   "source": [
    "然后同样是调用save方法，我们把模型存储到model文件夹里，取名叫做mlp.params。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9217bad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), 'model/mlp.params')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c675d52a",
   "metadata": {},
   "source": [
    "接下来，我们来读取保存好的模型参数，重新加载我们的模型。我们先把模型params参数读取出来，然后实例化一个模型，然后直接调用load_state_dict方法，传入模型参数params。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5eb0cc90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params = torch.load('model/mlp.params')\n",
    "model_load = MLP(input_size, hidden_size, num_classes)\n",
    "model_load.load_state_dict(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd3cb009",
   "metadata": {},
   "source": [
    "此时两个模型model和model_load具有相同的参数，我们给他输入相同的X，看一下输出结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c2fa0ffc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0914,  0.0178,  0.0692,  0.1486,  0.1002,  0.0057, -0.1099,  0.1332,\n",
       "          0.0241,  0.1137],\n",
       "        [-0.0228,  0.0446,  0.1374,  0.2009, -0.0978, -0.0831, -0.0193,  0.1040,\n",
       "          0.1097,  0.1484]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output1 = model(X)\n",
    "output1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "36a469e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0914,  0.0178,  0.0692,  0.1486,  0.1002,  0.0057, -0.1099,  0.1332,\n",
       "          0.0241,  0.1137],\n",
       "        [-0.0228,  0.0446,  0.1374,  0.2009, -0.0978, -0.0831, -0.0193,  0.1040,\n",
       "          0.1097,  0.1484]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output2 = model_load(X)\n",
    "output2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "155494fe",
   "metadata": {},
   "source": [
    "可以看到，输出的结果完全一致，说明我们将参数成功地读取并载入了模型中。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6f4bd3d",
   "metadata": {},
   "source": [
    "**梗直哥提示：使用save保存的是模型参数而不是整个模型，因此在模型加载参数的时候，需要我们单独指定模型架构，并且要保证模型结构和保存的时候一致，否则可能会导致参数加载失败。如果你想了解更多内容，欢迎入群学习（加V: gengzhige99）**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a20c5d0a",
   "metadata": {},
   "source": [
    "[Next 6-1 最优化理论和深度学习](../Chapter-06%20梯度下降算法及其变体/6-1%20最优化理论和深度学习.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b3e2eef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
